{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Pytorch Attention COMP 5331 - Ayush",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9ebVHs71cD6E"
      },
      "source": [
        "# **Equations 7, 8, 9 in the Fusion Sub-network**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iu9QKlKTW2aD"
      },
      "source": [
        "**Useful implementation code:**\n",
        "\n",
        "https://sigmoidal.io/implementing-additive-attention-in-pytorch/\n",
        "\n",
        "The paper we are following uses additive attention, also called \"Bahdanau attention\" after the paper by Bahdanau et al. that introduced it.\n",
        "\n",
        "The following implementation is based on an RNN for seq2seq modeling.\n",
        "\n",
        "Please go to the last code block; it should be the one that we need."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5aoWOagHfsRb"
      },
      "source": [
        "**Other good examples:**\n",
        "\n",
        "https://bastings.github.io/annotated_encoder_decoder/\n",
        "\n",
        "https://tomekkorbak.com/2020/06/26/implementing-attention-in-pytorch/"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Yz-9_fP5c-yE"
      },
      "source": [
        "import torch"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w8Cv2NAEXl7B"
      },
      "source": [
        "**Encoder and Decoder**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "b_iMBdt16-ow"
      },
      "source": [
        "class AdditiveAttention_Encoder_Decoder(torch.nn.Module):\n",
        "    def __init__(self, encoder_dim=100, decoder_dim=50):\n",
        "        super().__init__()\n",
        "\n",
        "        self.encoder_dim = encoder_dim\n",
        "        self.decoder_dim = decoder_dim\n",
        "        self.v = torch.nn.Parameter(torch.rand(self.decoder_dim))\n",
        "        self.W_1 = torch.nn.Linear(self.decoder_dim, self.decoder_dim)\n",
        "        self.W_2 = torch.nn.Linear(self.encoder_dim, self.decoder_dim)\n",
        "\n",
        "    def forward(self,\n",
        "      query, # [decoder_dim]\n",
        "      values # [seq_length, encoder_dim]\n",
        "    ):\n",
        "        weights = self._get_weights(query, values) # [seq_length]\n",
        "        weights = torch.nn.functional.softmax(weights, dim=0)\n",
        "        return weights @ values # [encoder_dim]\n",
        "\n",
        "    def _get_weights(self,\n",
        "      query, # [decoder_dim]\n",
        "      values # [seq_length, encoder_dim]\n",
        "    ):\n",
        "        query = query.repeat(values.size(0), 1) # [seq_length, decoder_dim]\n",
        "        weights = self.W_1(query) + self.W_2(values) # [seq_length, decoder_dim]\n",
        "        return torch.tanh(weights) @ self.v # [seq_length]"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JeybWBrcXqVG"
      },
      "source": [
        "**Encoder only**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qMesgC41XsgO"
      },
      "source": [
        "class AdditiveAttention_Encoder(torch.nn.Module):\n",
        "    # Encoder-only additive attention: there is no decoder query, so the scores depend on `values` alone.\n",
        "    # Please adjust encoder_dim to the size of the input vectors.\n",
        "    # attention_dim (hidden size of the scoring layer) is a name introduced here; it takes over\n",
        "    # the role that decoder_dim plays in the encoder-decoder version.\n",
        "    def __init__(self, encoder_dim=100, attention_dim=50):\n",
        "        super().__init__()\n",
        "\n",
        "        self.encoder_dim = encoder_dim\n",
        "        self.attention_dim = attention_dim\n",
        "        self.v = torch.nn.Parameter(torch.rand(self.attention_dim))\n",
        "        # Linear transform, y = x @ A.T + b. Hypothesis: no separate bias term is needed,\n",
        "        # because torch.nn.Linear adds one by default (bias=True):\n",
        "        # https://pytorch.org/docs/stable/generated/torch.nn.Linear.html\n",
        "        self.W_2 = torch.nn.Linear(self.encoder_dim, self.attention_dim)\n",
        "\n",
        "    def forward(self,\n",
        "      values # [seq_length, encoder_dim]\n",
        "    ):\n",
        "        weights = self._get_weights(values) # [seq_length]\n",
        "        weights = torch.nn.functional.softmax(weights, dim=0)\n",
        "        return weights @ values # [encoder_dim]\n",
        "\n",
        "    def _get_weights(self,\n",
        "      values # [seq_length, encoder_dim]\n",
        "    ):\n",
        "        weights = self.W_2(values) # [seq_length, attention_dim]\n",
        "        return torch.tanh(weights) @ self.v # [seq_length]"
      ],
      "execution_count": 4,
      "outputs": []
    },
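    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Read as equations, the encoder-only attention above appears to compute the following. This is a sketch of what the code does, not a transcription of Equations 7, 8 and 9 from the paper, so please check it against them. The notation is assumed here: $l_i$ is row $i$ of `values`, $W$ and $b$ are the weight and bias of `W_2`, $v$ is the parameter `v`, $e_i$ is the unnormalized score, $\\alpha_i$ is the attention weight, and $u$ is the returned vector.\n",
        "\n",
        "$$e_i = v^\\top \\tanh(W l_i + b)$$\n",
        "\n",
        "$$\\alpha_i = \\frac{\\exp(e_i)}{\\sum_j \\exp(e_j)}$$\n",
        "\n",
        "$$u = \\sum_i \\alpha_i \\, l_i$$\n",
        "\n",
        "Compared with the encoder-decoder version, the only change is that the query term `W_1(query)` is dropped from inside the $\\tanh$, so each score depends on its own $l_i$ alone."
      ]
    },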
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d_ydQDF0czgU"
      },
      "source": [
        "**Execute:**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OMGQ8fgXc5-U"
      },
      "source": [
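        "values = torch.rand(5, 100)  # placeholder for the stacked encoder outputs [l0, l1, l2, l3, l4]: [seq_length, encoder_dim]\n",
        "u = AdditiveAttention_Encoder(encoder_dim=100)(values)  # attended vector u: [encoder_dim]"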
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pUw4TTw2cA5M"
      },
      "source": [
        "# **Equation 10**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "i-FMSmJIciYl"
      },
      "source": [
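        "num_classes = 2  # assumed number of classes; set this to match the task\n",
        "fc = torch.nn.Linear(u.size(0), num_classes)  # maps u [encoder_dim] to class scores [num_classes]\n",
        "p = torch.nn.functional.softmax(fc(u), dim=0)  # class probabilities, [num_classes]"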
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LZnitUiyb8RU"
      },
      "source": [
        "# **Combine model?**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-TeM57MQfR_O"
      },
      "source": [
        "# **Loss function**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9c8rv_bueTI2"
      },
      "source": [
        "# In PyTorch: loss_fn = torch.nn.CrossEntropyLoss()  (takes raw logits and applies log-softmax internally)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bn9_7cFEb-40"
      },
      "source": [
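        "# PyTorch training-loop sketch in place of the Keras-style compile()/fit() call, which torch.nn.Module does not provide.\n",
        "# Assumes `model`, `X_train` and `y_train` are already defined, that model(x) returns raw logits of shape [num_classes],\n",
        "# and that each label y is a class-index tensor (dtype long). CrossEntropyLoss follows the note in the cell above.\n",
        "loss_fn = torch.nn.CrossEntropyLoss()\n",
        "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # lr=0.01 is an assumed value\n",
        "for epoch in range(4):\n",
        "    for x, y in zip(X_train, y_train):  # one sample per step, i.e. batch_size=1\n",
        "        optimizer.zero_grad()\n",
        "        loss_fn(model(x).unsqueeze(0), y.view(1)).backward()\n",
        "        optimizer.step()"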
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xOiO8qLqenQK"
      },
      "source": [
        "# **Example of full image classifier, different from ours**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lBVLoG8Setu6"
      },
      "source": [
        "**https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py**"
      ]
    }
  ]
}